﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace FeedbackNetwork.network.layer
{
    /// <summary>
    /// Rectified linear unit layer. <see cref="Forward"/> sets every element that is
    /// less than or equal to zero to <c>0f</c> and leaves the other elements unchanged.
    /// </summary>
    /// <remarks>
    /// Both passes modify the tensor they are given in place and return it.
    /// <see cref="Forward"/> records which elements were strictly positive so that
    /// <see cref="Backward"/> can apply the ReLU derivative: 1 where the forward input
    /// was &gt; 0, and 0 elsewhere, including at exactly 0.
    /// </remarks>
    public class Relu : Layer
    {
        // Derivative mask captured by the most recent Forward call: true where the
        // forward input was strictly positive. Null until Forward has run once.
        private bool[] _positiveMask;

        /// <summary>
        /// Creates a ReLU layer.
        /// </summary>
        /// <remarks>
        /// NOTE(review): base(1, 1) presumably means one input and one output; confirm against <c>Layer</c>.
        /// </remarks>
        public Relu():base(1, 1)
        {

        }

        /// <summary>
        /// Propagates a gradient back through the ReLU by zeroing the entries whose
        /// corresponding forward input was not strictly positive.
        /// </summary>
        /// <param name="input">
        /// Incoming tensors; only <c>input[0]</c> is read and modified in place.
        /// Presumably the gradient from the following layer; verify against callers.
        /// </param>
        /// <returns>The same <paramref name="input"/> array, with <c>input[0]</c> updated.</returns>
        /// <exception cref="ArgumentException"><paramref name="input"/> is null, empty, or has a null first element.</exception>
        /// <exception cref="InvalidOperationException">
        /// <see cref="Forward"/> has not been called yet, or the gradient length differs from the last forward input.
        /// </exception>
        public override FloatTensor[] Backward(FloatTensor[] input)
        {
            if (input is null || input.Length == 0 || input[0] is null)
            {
                throw new ArgumentException("Backward requires at least one non-null tensor.", nameof(input));
            }
            if (_positiveMask is null)
            {
                throw new InvalidOperationException("Forward must be called before Backward.");
            }

            float[] grad = input[0].GetData();
            if (grad.Length != _positiveMask.Length)
            {
                throw new InvalidOperationException(
                    "Gradient length does not match the length of the last Forward input.");
            }

            for (int i = 0; i < grad.Length; i++)
            {
                // The ReLU derivative depends on the forward input, not on the sign of the gradient:
                // gradients pass unchanged where the unit was active and are blocked elsewhere.
                if (!_positiveMask[i])
                {
                    grad[i] = 0f;
                }
            }

            // Write the values back explicitly in case GetData returns a copy rather than the backing array.
            input[0].SetData(grad, input[0].GetShape());
            return input;
        }

        /// <summary>
        /// Applies ReLU element-wise, in place: values &lt;= 0 become <c>0f</c>.
        /// Also records which elements were strictly positive, for use by <see cref="Backward"/>.
        /// </summary>
        /// <param name="input">Tensor to transform; it is modified in place.</param>
        /// <returns>The same <paramref name="input"/> instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
        public override FloatTensor Forward(FloatTensor input)
        {
            if (input is null)
            {
                throw new ArgumentNullException(nameof(input));
            }

            float[] data = input.GetData();
            bool[] mask = new bool[data.Length];
            for (int i = 0; i < data.Length; i++)
            {
                if (data[i] > 0f)
                {
                    mask[i] = true;
                }
                else if (data[i] <= 0f)
                {
                    // Only non-positive values are clamped; NaN fails both comparisons and passes through unchanged.
                    data[i] = 0f;
                }
            }

            _positiveMask = mask;
            input.SetData(data, input.GetShape());
            return input;
        }

        /// <summary>Returns the display name of this layer ("ReLU layer" in Chinese).</summary>
        public override string ToString() => "ReLU 层";
    }
}
